from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv

# Load the local .env file located by find_dotenv(); it is expected to define
# OPENAI_API_KEY. The return value is intentionally discarded.
_ = load_dotenv(dotenv_path=find_dotenv())

# Chat model shared by the chain below.
# Alternative model that can be swapped in:
#   llm = ChatOpenAI(temperature=0, model="gpt-4")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

# Two-message chat prompt: a fixed system instruction that asks for an
# EQUATION/SOLUTION layout, followed by a human message filled in from the
# "equation_statement" template variable.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            # Adjacent literals are concatenated at compile time into a single
            # instruction string; the pieces are split only for readability.
            "Write out the following equation using algebraic symbols then solve it. "
            "Use the format\n\n"
            "EQUATION:...\n"
            "SOLUTION:...\n\n",
        ),
        ("human", "{equation_statement}"),
    ]
)

# Chain, left to right: place the raw input under the "equation_statement" key,
# feed it to the prompt, send the prompt to the model, then pass the model
# output through StrOutputParser.
#
# Variant: swap the `llm` step for `llm.bind(stop="SOLUTION")`.
# Runnable.bind() attaches constant arguments to a Runnable object.
runnable = {"equation_statement": RunnablePassthrough()} | prompt | llm | StrOutputParser()

# Run the chain once on a sample word problem and print the result.
print(runnable.invoke("x raised to the third plus seven equals 12"))